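# # [SUPERSEDED] Earlier revision of spectral.py (self-contained, with demo), kept
# # commented out for reference only; the live implementation follows further below.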

# import numpy as np
# import scipy.sparse as sp
# from scipy.sparse.csgraph import connected_components
# from scipy.sparse.linalg import eigsh

# try:
#     from sklearn.manifold import spectral_embedding
#     from sklearn.decomposition import PCA
#     from sklearn.datasets import make_moons
#     import matplotlib.pyplot as plt
#     HAS_SKLEARN = True
# except Exception:
#     HAS_SKLEARN = False

# try:
#     import pyamg
#     HAS_PYAMG = True
# except Exception:
#     HAS_PYAMG = False


# # ---------------- utilities ----------------

# def _build_W(n: int, ei: np.ndarray, ej: np.ndarray, P: np.ndarray) -> sp.csr_matrix:
#     W = sp.coo_matrix((P.astype(np.float32, copy=False), (ei, ej)), shape=(n, n), dtype=np.float32).tocsr()
#     W = (W + W.T)
#     W.sum_duplicates()
#     W.eliminate_zeros()
#     return W

# def _spectral_small(W: sp.csr_matrix, tol: float = 1e-3, backend: str = "auto") -> np.ndarray:
#     if not HAS_SKLEARN:
#         raise ImportError("scikit-learn is required")
#     solver = ("amg" if HAS_PYAMG and backend in ("auto", "amg") else
#               "lobpcg" if backend in ("auto", "lobpcg") else backend)
#     Y = spectral_embedding(
#         W, n_components=3, eigen_solver=solver,
#         norm_laplacian=True, drop_first=False,
#         random_state=0, eigen_tol=tol,
#     )
#     return Y[:, 1:3].astype(np.float32, copy=False)

# def _component_indices(labels: np.ndarray, c: int) -> np.ndarray:
#     return np.nonzero(labels == c)[0]

# def _center_unit(Y: np.ndarray) -> np.ndarray:
#     Y = Y - Y.mean(axis=0, keepdims=True)
#     r = np.sqrt((Y**2).sum(axis=1).mean()) + 1e-12  # RMS radius
#     return (Y / r).astype(np.float32, copy=False)

# def _pack_on_ring(sizes: np.ndarray, gap: float = 2.0) -> np.ndarray:
#     phi = (np.sqrt(5.0) - 1.0) * np.pi  # golden angle
#     offs = []
#     R0 = gap
#     for t, s in enumerate(sizes):
#         R = R0 + gap * np.sqrt(float(s))
#         offs.append(np.array([R * np.cos(t * phi), R * np.sin(t * phi)], dtype=np.float32))
#     return np.vstack(offs)


# # ---------------- UMAP-style meta-layout helpers ----------------

# def _euclidean_pairwise_distances(M: np.ndarray) -> np.ndarray:
#     """Row-wise Euclidean pairwise distances. M: (C, d) -> (C, C)."""
#     if M.ndim != 2:
#         M = np.atleast_2d(M)
#     ss = np.sum(M * M, axis=1, dtype=np.float64)
#     D2 = np.maximum(0.0, ss[:, None] + ss[None, :] - 2.0 * (M @ M.T))
#     return np.sqrt(D2, dtype=np.float64)


# def _umap_meta_positions(centroids: np.ndarray, dim: int = 2, random_state: int = 0) -> np.ndarray:
#     """Compute UMAP-style meta positions for components.

#     If the number of components is small (<= 2*dim) use a simple ±basis layout.
#     Otherwise, compute a spectral embedding of an affinity built from
#     pairwise centroid distances with A_ij = exp(-||c_i - c_j||^2) as in UMAP's
#     component_layout.
#     """
#     C = centroids.shape[0]
#     if C <= 2 * dim:
#         k = int(np.ceil(C / 2.0))
#         base = np.hstack([np.eye(k, dtype=np.float32), np.zeros((k, dim - k), dtype=np.float32)])
#         meta = np.vstack([base, -base])[:C]
#         return meta.astype(np.float32, copy=False)

#     D = _euclidean_pairwise_distances(centroids).astype(np.float64, copy=False)
#     A = np.exp(-(D ** 2))  # matches UMAP's component_layout affinity

#     if HAS_SKLEARN:
#         try:
#             from sklearn.manifold import SpectralEmbedding
#             meta = SpectralEmbedding(n_components=dim, affinity="precomputed", random_state=random_state).fit_transform(A)
#         except Exception:
#             # very small or degenerate cases: fall back to PCA of centroids
#             from sklearn.decomposition import PCA as _PCA
#             meta = _PCA(n_components=dim, svd_solver="randomized", random_state=random_state).fit_transform(centroids)
#     else:
#         # Fallback without sklearn: center A and take top-2 left singular vectors
#         A0 = A - A.mean(axis=0, keepdims=True)
#         U, S, _ = np.linalg.svd(A0, full_matrices=False)
#         meta = U[:, :dim] * S[:dim]

#     denom = float(np.max(np.abs(meta)) + 1e-12)
#     return (meta / denom).astype(np.float32, copy=False)


# # ---------------- corrected Nyström (normalized adjacency) ----------------
# def _nystrom_init_fixed(W, m=None, tol=1e-3, random_state=0, t_steps=5, lazy=.1):
#     n = W.shape[0]
#     if m is None:
#         m = int(min(n, max(256, int(4 * np.sqrt(n)))))
#     if m >= n:
#         m = max(2, n // 2)
#     if m < 2:
#         return np.zeros((n, 2), dtype=np.float32)

#     rng = np.random.default_rng(random_state)
#     B = np.sort(rng.choice(n, size=m, replace=False))

#     d = (np.asarray(W.sum(axis=1)).ravel().astype(np.float32) + 1e-12)
#     Dmh = 1.0 / np.sqrt(d)
#     A = (sp.diags(Dmh) @ W @ sp.diags(Dmh)).astype(np.float32)

#     if lazy > 0.0:
#         I = sp.eye(n, dtype=np.float32, format="csr")
#         A = (1.0 - lazy) * I + lazy * A

#     A_BB   = A[B][:, B].tocsr()
#     A_cols = A[:, B].tocsr()

#     for _ in range(max(1, t_steps) - 1):
#         A_cols = (A @ A_cols).tocsr().astype(np.float32, copy=False)

#     mB = A_BB.shape[0]
#     k  = max(1, min(3, mB - 1))  # ensure k < mB
#     if k < 1:
#         return np.zeros((n, 2), dtype=np.float32)

#     # tiny diagonal regularization (optional)
#     A_BB = (A_BB + 1e-7 * sp.eye(mB, dtype=np.float32)).astype(np.float64)

#     vals, vecs = eigsh(A_BB, k=k, which="LA", tol=tol, maxiter=200)
#     order = np.argsort(vals)[::-1]
#     vals, vecs = vals[order], vecs[:, order]

#     take = min(2, vecs.shape[1] - 1)
#     if take < 1:
#         return np.zeros((n, 2), dtype=np.float32)

#     U_B   = vecs[:, 1:1+take].astype(np.float32, copy=False)
#     lam   = vals[1:1+take].astype(np.float32, copy=False)
#     U_all = (A_cols @ U_B).astype(np.float32, copy=False) / (lam[None, :] ** max(1, t_steps) + 1e-12)

#     Y = U_all - U_all.mean(axis=0, keepdims=True)
#     Y /= (np.sqrt((Y * Y).sum(axis=1).mean()) + 1e-12)
#     if Y.shape[1] == 1:
#         Y = np.hstack([Y, np.zeros((n, 1), dtype=np.float32)])
#     return Y



# # ---------------- public API ----------------

# def init_spectral(
#     X: np.ndarray, P: np.ndarray, ei: np.ndarray, ej: np.ndarray, n: int,
#     *,
#     init_mode: str = "standard",            # "standard" | "nystrom" | "pca"
#     component_strategy: str = "umap",       # "pack" | "anchor" | "umap"
#     backend: str = "auto",
#     tol: float = 1e-3,
#     random_state = 0
# ) -> np.ndarray:
#     if init_mode == "pca":
#         if not HAS_SKLEARN:
#             raise ImportError("scikit-learn PCA is required for init_mode='pca'")
#         Y = PCA(n_components=2, svd_solver="randomized", random_state=0).fit_transform(X).astype(np.float32, copy=False)
#         scale = X.std(axis=0).mean() / (Y.std() + 1e-12)
#         return (Y * scale).astype(np.float32, copy=False)

#     W = _build_W(n, ei, ej, P)
#     n_cc, labels = connected_components(W, directed=False, return_labels=True)

#     if n_cc == 1:
#         if init_mode == "standard":
#             Y = _spectral_small(W, tol=tol, backend=backend)
#         elif init_mode == "nystrom":
#             # >>> FIX: call the corrected Nyström <<<
#             Y = _nystrom_init_fixed(W, tol=tol, random_state=random_state)
#         else:
#             raise ValueError(f"Unknown init_mode: {init_mode!r}")
#         scale = X.std(axis=0).mean() / (Y.std() + 1e-12)
#         return (Y * scale).astype(np.float32, copy=False)

#     # Multiple components: embed each CC and pack
#     comps, sizes, idx_lists = [], [], []
#     for c in range(n_cc):
#         idx = _component_indices(labels, c)
#         idx_lists.append(idx)
#         Wc = W[idx][:, idx].tocsr()
#         if init_mode == "standard":
#             Yc = _spectral_small(Wc, tol=tol, backend=backend)
#         elif init_mode == "nystrom":
#             # >>> FIX: call the corrected Nyström here too <<<
#             Yc = _nystrom_init_fixed(Wc, tol=tol)
#         else:
#             raise ValueError(f"Unknown init_mode: {init_mode!r}")
#         comps.append(_center_unit(Yc))
#         sizes.append(len(idx))
#     sizes = np.array(sizes, dtype=np.int64)

#     if component_strategy == "anchor":
#         if not HAS_SKLEARN:
#             raise ImportError("scikit-learn PCA is required for component_strategy='anchor'")
#         centroids = np.vstack([X[idx].mean(axis=0) for idx in idx_lists])
#         anchor2d = PCA(n_components=2, svd_solver="randomized", random_state=0).fit_transform(centroids).astype(np.float32, copy=False)
#         anchor2d = _center_unit(anchor2d)
#         offs = anchor2d * (1.5 * np.sqrt(sizes)[:, None].astype(np.float32))
#         Y = np.zeros((n, 2), dtype=np.float32)
#         for (idx, Yc, off) in zip(idx_lists, comps, offs):
#             Y[idx] = Yc + off
#         scale = X.std(axis=0).mean() / (Y.std() + 1e-12)
#         return (Y * scale).astype(np.float32, copy=False)

#     elif component_strategy == "pack":
#         offs = _pack_on_ring(sizes, gap=2.0)
#         Y = np.zeros((n, 2), dtype=np.float32)
#         for (idx, Yc, off) in zip(idx_lists, comps, offs):
#             Y[idx] = Yc + off
#         scale = X.std(axis=0).mean() / (Y.std() + 1e-12)
#         return (Y * scale).astype(np.float32, copy=False)

#     elif component_strategy == "umap":
#         # --- UMAP-style global placement with explicit non-overlap ---
#         centroids = np.vstack([X[idx].mean(axis=0) for idx in idx_lists]).astype(np.float32, copy=False)
#         meta = _umap_meta_positions(centroids, dim=2, random_state=random_state)

#         Y = np.zeros((n, 2), dtype=np.float32)
#         rng = np.random.default_rng(random_state)
#         C = meta.shape[0]
#         for c, (idx, Yc) in enumerate(zip(idx_lists, comps)):
#             if C > 1:
#                 d = np.sqrt(((meta[c][None, :] - meta) ** 2).sum(axis=1))
#                 d = d[d > 0.0]
#                 data_range = float(d.min() * 0.5) if d.size > 0 else 1.0
#             else:
#                 data_range = 1.0

#             # Very small components: random box around the meta position
#             if Yc.shape[0] < 4:
#                 Y[idx] = rng.uniform(low=-data_range, high=data_range, size=(Yc.shape[0], 2)).astype(np.float32) + meta[c]
#                 continue

#             maxabs = float(np.max(np.abs(Yc)) + 1e-12)
#             expansion = data_range / maxabs
#             Y[idx] = (Yc * expansion + meta[c]).astype(np.float32, copy=False)

#         # IMPORTANT: do not apply a global rescale here; it would break the
#         # non-overlap/data_range guarantees encoded by the meta layout.
#         return Y

#     else:
#         raise ValueError(f"Unknown component_strategy: {component_strategy!r}")


# # ---------------- demo ----------------

# if __name__ == "__main__":
#     # DEMO: moons dataset, build a graph, try all inits
#     import sys, os
#     sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
#     from graph_utils.graph_builder import build_weighted_graph
#     from sklearn.preprocessing import StandardScaler

#     np.random.seed(0)
#     X, y = make_moons(n_samples=3000, noise=0.08)
#     X = StandardScaler().fit_transform(X)

#     # Build kNN graph (default settings)
#     k = 10
#     rho, sigmas, ei, ej, P_vals, neigh_idx = build_weighted_graph(X, k=k)

#     modes = ["standard", "nystrom", "pca"]
#     fig, axs = plt.subplots(1, len(modes), figsize=(4 * len(modes), 4), constrained_layout=True)

#     for i, mode in enumerate(modes):
#         Y0 = init_spectral(X, P_vals, ei, ej, n=X.shape[0], init_mode=mode)
#         axs[i].scatter(Y0[:, 0], Y0[:, 1], c=y, cmap="Spectral", s=18, edgecolor="k", linewidth=0.4)
#         axs[i].set_title(f"Init mode: {mode}")
#         axs[i].set_xticks([]); axs[i].set_yticks([])

#     plt.show()

# spectral.py (self-contained, with demo)
# --------------------------------------
# Spectral initializations for graph embeddings with multiple strategies:
# - standard      : sklearn spectral embedding on the normalized graph Laplacian
# - nystrom       : corrected Nyström on normalized adjacency (uniform anchors)
# - nystrompp     : corrected Nyström with k-means++ anchors
# - diffmap       : diffusion maps (scale by lambda^t)
# - pca           : 2-component PCA of the features X (done per connected
#                   component when the graph is disconnected)
#
# Component placement strategies:
# - anchor        : PCA on component centroids, offsets ~ sqrt(size)
# - pack          : golden-angle ring pack, offsets ~ sqrt(size)
# - umap          : UMAP-style meta spectral embedding of components using
#                   exp(-d^2) affinity + explicit non-overlap scaling
#
# The file includes a self-contained demo using a small kNN graph builder.

from __future__ import annotations

import numpy as np
import scipy.sparse as sp
from scipy.sparse.csgraph import connected_components
from scipy.sparse.linalg import eigsh

# Optional dependencies
try:
    from sklearn.manifold import spectral_embedding, SpectralEmbedding
    from sklearn.decomposition import PCA, TruncatedSVD
    from sklearn.datasets import make_moons
    from sklearn.preprocessing import StandardScaler
    from sklearn.neighbors import NearestNeighbors
    import matplotlib.pyplot as plt
    HAS_SKLEARN = True
except Exception:
    HAS_SKLEARN = False

try:
    import pyamg
    HAS_PYAMG = True
except Exception:
    HAS_PYAMG = False


# ---------------- utilities ----------------

def _build_W(n: int, ei: np.ndarray, ej: np.ndarray, P: np.ndarray) -> sp.csr_matrix:
    """Build symmetric sparse adjacency from edge lists and weights."""
    W = sp.coo_matrix(
        (P.astype(np.float32, copy=False), (ei, ej)),
        shape=(n, n),
        dtype=np.float32,
    ).tocsr()
    W = (W + W.T)
    W.sum_duplicates()
    W.eliminate_zeros()
    return W

def _spectral_small(W: sp.csr_matrix, tol: float = 1e-3, backend: str = "auto") -> np.ndarray:
    """2-D spectral embedding of a graph using scikit-learn.

    Calls ``sklearn.manifold.spectral_embedding`` with the normalized Laplacian,
    asks for three eigenvectors (``drop_first=False``) and returns columns 1
    and 2, i.e. the leading column is dropped.

    Args:
        W: square sparse adjacency matrix.
        tol: forwarded as ``eigen_tol``.
        backend: "auto" (AMG when pyamg is installed, else LOBPCG), "amg",
            "lobpcg", or any other ``eigen_solver`` name, which is passed
            through unchanged. "amg" without pyamg is also passed through.

    Returns:
        float32 array of shape (n, 2). Graphs with fewer than two nodes give
        zeros. If the solver returns fewer than three eigenvectors (very small
        graphs), the missing columns are zero-filled so the shape is always (n, 2).

    Raises:
        ImportError: if scikit-learn is not available.
    """
    if not HAS_SKLEARN:
        raise ImportError("scikit-learn is required for 'standard' spectral init")
    n = W.shape[0]
    if n < 2:
        # A lone node has no non-trivial direction; place it at the origin.
        return np.zeros((n, 2), dtype=np.float32)

    if HAS_PYAMG and backend in ("auto", "amg"):
        solver = "amg"
    elif backend in ("auto", "lobpcg"):
        solver = "lobpcg"
    else:
        solver = backend
    if n <= 3:
        # Three eigenvectors of a <=3-node graph cannot come from ARPACK
        # (it needs k < n). NOTE(review): sklearn's lobpcg path solves such tiny
        # graphs densely; confirm against the installed sklearn version.
        solver = "lobpcg"

    Y = spectral_embedding(
        W, n_components=3, eigen_solver=solver,
        norm_laplacian=True, drop_first=False,
        random_state=0, eigen_tol=tol,
    )
    Y = np.asarray(Y)[:, 1:3].astype(np.float32, copy=False)
    if Y.shape[1] < 2:
        # Pad explicitly so callers never broadcast a 0/1-column result.
        Y = np.hstack([Y, np.zeros((n, 2 - Y.shape[1]), dtype=np.float32)])
    return Y

def _component_indices(labels: np.ndarray, c: int) -> np.ndarray:
    return np.nonzero(labels == c)[0]

def _center_unit(Y: np.ndarray) -> np.ndarray:
    """Center and scale to unit RMS radius."""
    Y = Y - Y.mean(axis=0, keepdims=True)
    r = np.sqrt((Y**2).sum(axis=1).mean()) + 1e-12  # RMS radius
    return (Y / r).astype(np.float32, copy=False)

def _pack_on_ring(sizes: np.ndarray, gap: float = 2.0) -> np.ndarray:
    """Place components on a golden-angle ring with radius ~ sqrt(size)."""
    phi = (np.sqrt(5.0) - 1.0) * np.pi  # golden angle
    offs = []
    R0 = gap
    for t, s in enumerate(sizes):
        R = R0 + gap * np.sqrt(float(s))
        offs.append(np.array([R * np.cos(t * phi), R * np.sin(t * phi)], dtype=np.float32))
    return np.vstack(offs)


# ---------------- UMAP-style meta-layout helpers ----------------

def _euclidean_pairwise_distances(M: np.ndarray) -> np.ndarray:
    """Row-wise Euclidean pairwise distances. M: (C, d) -> (C, C)."""
    if M.ndim != 2:
        M = np.atleast_2d(M)
    ss = np.sum(M * M, axis=1, dtype=np.float64)
    D2 = np.maximum(0.0, ss[:, None] + ss[None, :] - 2.0 * (M @ M.T))
    return np.sqrt(D2)

def _umap_meta_positions(centroids: np.ndarray, dim: int = 2, random_state: int = 0) -> np.ndarray:
    """Compute UMAP-style meta positions for components.

    If the number of components is small (<= 2*dim) use a simple ±basis layout.
    Otherwise, compute a spectral embedding of an affinity built from
    pairwise centroid distances with A_ij = exp(-||c_i - c_j||^2).
    """
    C = centroids.shape[0]
    if C <= 2 * dim:
        k = int(np.ceil(C / 2.0))
        base = np.hstack([np.eye(k, dtype=np.float32), np.zeros((k, dim - k), dtype=np.float32)])
        meta = np.vstack([base, -base])[:C]
        return meta.astype(np.float32, copy=False)

    D = _euclidean_pairwise_distances(centroids).astype(np.float64, copy=False)
    A = np.exp(-(D ** 2))  # matches UMAP's component_layout affinity

    if HAS_SKLEARN:
        try:
            meta = SpectralEmbedding(
                n_components=dim,
                affinity="precomputed",
                random_state=random_state
            ).fit_transform(A)
        except Exception:
            # very small or degenerate cases: fall back to PCA of centroids
            meta = PCA(n_components=dim, svd_solver="randomized", random_state=random_state)\
                   .fit_transform(centroids)
    else:
        # Fallback without sklearn: centered SVD of A
        A0 = A - A.mean(axis=0, keepdims=True)
        U, S, _ = np.linalg.svd(A0, full_matrices=False)
        meta = U[:, :dim] * S[:dim]

    denom = float(np.max(np.abs(meta)) + 1e-12)
    return (meta / denom).astype(np.float32, copy=False)


# ---------------- advanced sampling & diffusion helpers ----------------

def _kmeanspp_indices(X: np.ndarray, m: int, rng: np.random.Generator) -> np.ndarray:
    """Lightweight k-means++ seeding to pick m representative anchor indices.
    If m >= n it returns all indices sorted.
    """
    n = X.shape[0]
    if m >= n:
        return np.arange(n, dtype=np.int64)
    first = int(rng.integers(n))
    centers = [first]
    D2 = np.sum((X - X[first]) ** 2, axis=1, dtype=np.float64)
    for _ in range(1, m):
        s = D2.sum() + 1e-12
        probs = D2 / s
        j = int(rng.choice(n, p=probs))
        centers.append(j)
        dnew = np.sum((X - X[j]) ** 2, axis=1, dtype=np.float64)
        D2 = np.minimum(D2, dnew)
    return np.sort(np.array(centers, dtype=np.int64))

def _diffusion_map_small(W: sp.csr_matrix, t: float = 1.0, lazy: float = 0.1, tol: float = 1e-3) -> np.ndarray:
    """2-D diffusion maps embedding for a single connected component.

    Build symmetric normalized adjacency A = D^{-1/2} W D^{-1/2} (optionally lazified),
    compute top-3 eigenpairs, drop the trivial first, and scale by lambda^t.
    """
    n = W.shape[0]
    if n < 2:
        return np.zeros((n, 2), dtype=np.float32)

    d = (np.asarray(W.sum(axis=1)).ravel().astype(np.float64) + 1e-12)
    Dmh = 1.0 / np.sqrt(d)
    A = (sp.diags(Dmh) @ W @ sp.diags(Dmh)).astype(np.float64)

    if lazy > 0.0:
        I = sp.eye(n, dtype=np.float64, format="csr")
        A = (1.0 - lazy) * I + lazy * A

    k = min(3, n)
    if k <= 1:
        return np.zeros((n, 2), dtype=np.float32)

    vals, vecs = eigsh(A, k=k, which="LA", tol=tol, maxiter=max(5 * n, 1000))
    order = np.argsort(vals)[::-1]
    vals, vecs = vals[order], vecs[:, order]

    take = min(2, vecs.shape[1] - 1)
    if take < 1:
        return np.zeros((n, 2), dtype=np.float32)

    U = vecs[:, 1:1 + take].astype(np.float64, copy=False)
    lam = np.clip(vals[1:1 + take].astype(np.float64, copy=False), 0.0, 1.0)
    if float(t) != 0.0:
        U = U * (lam[None, :] ** float(t))

    Y = U - U.mean(axis=0, keepdims=True)
    r = np.sqrt((Y * Y).sum(axis=1).mean()) + 1e-12
    Y = (Y / r).astype(np.float32, copy=False)
    if Y.shape[1] == 1:
        Y = np.hstack([Y, np.zeros((n, 1), dtype=np.float32)])
    return Y


# ---------------- corrected Nyström (normalized adjacency) ----------------
def _nystrom_init_fixed(
    W,
    m: int | None = None,
    tol: float = 1e-3,
    random_state: int = 0,
    t_steps: int = 5,
    lazy: float = .1,
    X: np.ndarray | None = None,
    anchors: str = "uniform"   # "uniform" | "kmeans++" | "deg"
) -> np.ndarray:
    """Nyström approximation on normalized adjacency with optional smarter anchors."""
    n = W.shape[0]
    if m is None:
        m = int(min(n, max(256, int(4 * np.sqrt(n)))))
    if m >= n:
        m = max(2, n // 2)
    if m < 2:
        return np.zeros((n, 2), dtype=np.float32)

    rng = np.random.default_rng(random_state)
    if anchors == "kmeans++" and X is not None:
        B = _kmeanspp_indices(X, m, rng)
    elif anchors == "deg":
        deg = (np.asarray(W.sum(axis=1)).ravel().astype(np.float64) + 1e-12)
        p = deg / deg.sum()
        B = np.sort(rng.choice(n, size=m, replace=False, p=p))
    else:
        B = np.sort(rng.choice(n, size=m, replace=False))

    d = (np.asarray(W.sum(axis=1)).ravel().astype(np.float32) + 1e-12)
    Dmh = 1.0 / np.sqrt(d)
    A = (sp.diags(Dmh) @ W @ sp.diags(Dmh)).astype(np.float32)

    if lazy > 0.0:
        I = sp.eye(n, dtype=np.float32, format="csr")
        A = (1.0 - lazy) * I + lazy * A

    A_BB   = A[B][:, B].tocsr()
    A_cols = A[:, B].tocsr()

    for _ in range(max(1, t_steps) - 1):
        A_cols = (A @ A_cols).tocsr().astype(np.float32, copy=False)

    mB = A_BB.shape[0]
    k  = max(1, min(3, mB - 1))  # ensure k < mB
    if k < 1:
        return np.zeros((n, 2), dtype=np.float32)

    # tiny diagonal regularization (optional)
    A_BB = (A_BB + 1e-7 * sp.eye(mB, dtype=np.float32)).astype(np.float64)

    vals, vecs = eigsh(A_BB, k=k, which="LA", tol=tol, maxiter=200)
    order = np.argsort(vals)[::-1]
    vals, vecs = vals[order], vecs[:, order]

    take = min(2, vecs.shape[1] - 1)
    if take < 1:
        return np.zeros((n, 2), dtype=np.float32)

    U_B   = vecs[:, 1:1 + take].astype(np.float32, copy=False)
    lam   = vals[1:1 + take].astype(np.float32, copy=False)
    U_all = (A_cols @ U_B).astype(np.float32, copy=False) / (lam[None, :] ** max(1, t_steps) + 1e-12)

    Y = U_all - U_all.mean(axis=0, keepdims=True)
    Y /= (np.sqrt((Y * Y).sum(axis=1).mean()) + 1e-12)
    if Y.shape[1] == 1:
        Y = np.hstack([Y, np.zeros((n, 1), dtype=np.float32)])
    return Y


# ---------------- public API ----------------

def _embed_connected(
    W: sp.csr_matrix,
    Xc: np.ndarray,
    init_mode: str,
    *,
    backend: str,
    tol: float,
    diffusion_t: float,
    random_state: int,
) -> np.ndarray:
    """Embed one connected graph ``W`` (node i <-> row i of ``Xc``) in 2-D with ``init_mode``."""
    if init_mode == "standard":
        return _spectral_small(W, tol=tol, backend=backend)
    if init_mode == "nystrom":
        return _nystrom_init_fixed(W, tol=tol, random_state=random_state)
    if init_mode == "nystrompp":
        return _nystrom_init_fixed(W, tol=tol, random_state=random_state, X=Xc, anchors="kmeans++")
    if init_mode == "diffmap":
        return _diffusion_map_small(W, t=diffusion_t, tol=tol)
    if init_mode == "pca":
        if not HAS_SKLEARN:
            raise ImportError("scikit-learn PCA is required for init_mode='pca'")
        if Xc.shape[0] < 2:
            # PCA needs at least two samples; a lone point sits at the origin.
            return np.zeros((Xc.shape[0], 2), dtype=np.float32)
        return (
            PCA(n_components=2, svd_solver="randomized", random_state=random_state)
            .fit_transform(Xc)
            .astype(np.float32, copy=False)
        )
    raise ValueError(f"Unknown init_mode: {init_mode!r}")


def _rescale_to_data(Y: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Scale ``Y`` so its overall std equals the mean per-feature std of ``X``; return float32."""
    scale = X.std(axis=0).mean() / (Y.std() + 1e-12)
    return (Y * scale).astype(np.float32, copy=False)


def _apply_offsets(n: int, idx_lists, comps, offs) -> np.ndarray:
    """Scatter each component layout, shifted by its offset, into an (n, 2) float32 array."""
    Y = np.zeros((n, 2), dtype=np.float32)
    for idx, Yc, off in zip(idx_lists, comps, offs):
        Y[idx] = Yc + off
    return Y


def _place_umap(X: np.ndarray, idx_lists, comps, n: int, random_state: int) -> np.ndarray:
    """Place components at UMAP-style meta positions with explicit non-overlap scaling."""
    centroids = np.vstack([X[idx].mean(axis=0) for idx in idx_lists]).astype(np.float32, copy=False)
    meta = _umap_meta_positions(centroids, dim=2, random_state=random_state)

    Y = np.zeros((n, 2), dtype=np.float32)
    rng = np.random.default_rng(random_state)
    C = meta.shape[0]
    for c, (idx, Yc) in enumerate(zip(idx_lists, comps)):
        if C > 1:
            # Half the distance to the nearest distinct meta position.
            d = np.sqrt(((meta[c][None, :] - meta) ** 2).sum(axis=1))
            d = d[d > 0.0]
            data_range = float(d.min() * 0.5) if d.size > 0 else 1.0
        else:
            data_range = 1.0

        # Very small components: random box around the meta position.
        if Yc.shape[0] < 4:
            Y[idx] = rng.uniform(low=-data_range, high=data_range, size=(Yc.shape[0], 2)).astype(np.float32) + meta[c]
            continue

        maxabs = float(np.max(np.abs(Yc)) + 1e-12)
        expansion = data_range / maxabs
        Y[idx] = (Yc * expansion + meta[c]).astype(np.float32, copy=False)
    return Y


def init_spectral(
    X: np.ndarray, P: np.ndarray, ei: np.ndarray, ej: np.ndarray, n: int,
    *,
    init_mode: str = "standard",            # "standard" | "nystrom" | "nystrompp" | "diffmap" | "pca"
    component_strategy: str = "pack",     # "pack" | "anchor" | "umap"
    backend: str = "auto",
    tol: float = 1e-3,
    diffusion_t: float = 1.0,
    random_state: int = 0
) -> np.ndarray:
    """Initialize 2-D positions for an n-node graph built from (ei, ej, P) over X.

    The graph is symmetrized with ``_build_W``.

    Connected graph: the whole graph is embedded with ``init_mode``. The
    layout is then rescaled so its overall std equals the mean per-feature
    std of ``X``.

    Disconnected graph: each connected component is embedded separately,
    centred and scaled to unit RMS radius. The components are then placed
    according to ``component_strategy``:

      * "anchor": offsets are the unit-RMS PCA of the component centroids in
        ``X``, times 1.5 * sqrt(size), followed by the global rescale above.
      * "pack": offsets on a golden-angle ring of radius 2 + 2 * sqrt(size),
        followed by the global rescale above.
      * "umap": components are centred on ``_umap_meta_positions``. Each is
        scaled so its largest coordinate equals half the distance to the
        nearest distinct meta position. Components with fewer than 4 nodes
        are scattered uniformly in a box of that half-width. No global
        rescale is applied, since it would break that spacing.

    Args:
        X: (n, d) features; used for scaling, "pca", "nystrompp" anchors and
            component centroids.
        P, ei, ej: edge weights and endpoints.
        n: number of nodes.
        init_mode: per-graph embedding method (see the signature comment).
        component_strategy: placement of disconnected components.
        backend, tol: eigensolver choice and tolerance.
        diffusion_t: diffusion time for "diffmap".
        random_state: seed used by every randomized step (Nyström anchors,
            PCA, UMAP meta layout and small-component scatter).

    Returns:
        float32 array of shape (n, 2).

    Raises:
        ValueError: unknown ``init_mode``, or (disconnected graph only) an
            unknown ``component_strategy``.
        ImportError: scikit-learn missing for "standard"/"pca" or "anchor".
    """
    if init_mode not in ("standard", "nystrom", "nystrompp", "diffmap", "pca"):
        raise ValueError(f"Unknown init_mode: {init_mode!r}")

    # Build symmetric adjacency
    W = _build_W(n, ei, ej, P)
    n_cc, labels = connected_components(W, directed=False, return_labels=True)
    embed_kw = dict(backend=backend, tol=tol, diffusion_t=diffusion_t, random_state=random_state)

    # Single connected component
    if n_cc == 1:
        return _rescale_to_data(_embed_connected(W, X, init_mode, **embed_kw), X)

    # Fail fast, before embedding every component.
    if component_strategy not in ("anchor", "pack", "umap"):
        raise ValueError(f"Unknown component_strategy: {component_strategy!r}")
    if component_strategy == "anchor" and not HAS_SKLEARN:
        raise ImportError("scikit-learn PCA is required for component_strategy='anchor'")

    # Group node indices per label in one pass. The stable sort keeps each
    # group in increasing index order; a per-label scan would be O(n * n_cc).
    order = np.argsort(labels, kind="stable")
    counts = np.bincount(labels, minlength=n_cc)
    idx_lists = np.split(order, np.cumsum(counts)[:-1])
    sizes = counts.astype(np.int64)

    comps = []
    for idx in idx_lists:
        Wc = W[idx][:, idx].tocsr()
        comps.append(_center_unit(_embed_connected(Wc, X[idx], init_mode, **embed_kw)))

    if component_strategy == "anchor":
        centroids = np.vstack([X[idx].mean(axis=0) for idx in idx_lists])
        anchor2d = PCA(n_components=2, svd_solver="randomized", random_state=random_state)\
                   .fit_transform(centroids).astype(np.float32, copy=False)
        anchor2d = _center_unit(anchor2d)
        offs = anchor2d * (1.5 * np.sqrt(sizes)[:, None].astype(np.float32))
        return _rescale_to_data(_apply_offsets(n, idx_lists, comps, offs), X)

    if component_strategy == "pack":
        offs = _pack_on_ring(sizes, gap=2.0)
        return _rescale_to_data(_apply_offsets(n, idx_lists, comps, offs), X)

    # "umap": no global rescale; it breaks the non-overlap/data_range spacing.
    return _place_umap(X, idx_lists, comps, n, random_state)


# ---------------- simple kNN graph builder (self-contained) ----------------

def build_weighted_graph(
    X: np.ndarray,
    k: int = 10,
    *,
    metric: str = "euclidean",
    local_scale: str = "knn",   # "knn" uses k-th neighbor distance as sigma_i
    mutual: bool = False,
    random_state: int = 0,
):
    """Self-contained kNN graph with locally scaled heat-kernel weights.

    For each point i and each of its nearest neighbours j, emits a directed
    edge (i, j) with weight exp(-d_ij^2 / (sigma_i * sigma_j)). Here sigma_i
    is the distance from i to its farthest kept neighbour, i.e. the k-th one.

    Args:
        X: (n, d) data matrix.
        k: neighbours per point. Capped at n - 1 on both the scikit-learn and
            the brute-force path.
        metric: distance metric; only "euclidean" is supported without
            scikit-learn.
        local_scale: only the k-th-neighbour scale is implemented, so every
            value currently behaves like "knn".
        mutual: if True, keep edge (i, j) only when i is also among j's
            neighbours. If i has no mutual neighbour, all its neighbours are kept.
        random_state: unused; accepted for API compatibility.

    Returns:
        rho: zeros (placeholder; kept for API similarity)
        sigmas: per-point local scale sigma_i (float32)
        ei, ej: edge endpoints (int64)
        P_vals: edge weights (float32)
        neigh_idx: full neighbour index array per point (list of arrays),
            recorded before any mutual filtering

    Raises:
        ValueError: if ``k < 1``, if ``X`` has fewer than 2 rows, or if a
            non-euclidean metric is requested without scikit-learn.
    """
    n = X.shape[0]
    if k < 1 or n < 2:
        raise ValueError(
            f"build_weighted_graph needs k >= 1 and at least 2 points (got k={k}, n={n})"
        )

    if HAS_SKLEARN:
        nbrs = NearestNeighbors(n_neighbors=min(k + 1, n), metric=metric)
        nbrs.fit(X)
        dists, inds = nbrs.kneighbors(X, return_distance=True)
        # Drop column 0, assumed to be the query point itself.
        # NOTE(review): with duplicate points another copy may come first and
        # the point itself survive as a neighbour (a self-loop edge).
        dists = dists[:, 1:]
        inds = inds[:, 1:]
    else:
        # brute-force (O(n^2)) for demo-scale data
        if metric != "euclidean":
            raise ValueError("Only 'euclidean' metric is supported without scikit-learn")
        # A point has at most n - 1 other points. Without this cap the slice
        # below would include the self column (distance inf) and give inf sigmas.
        kk = min(k, n - 1)
        ss = np.sum(X * X, axis=1, dtype=np.float64)
        D2 = np.maximum(0.0, ss[:, None] + ss[None, :] - 2.0 * (X @ X.T))
        np.fill_diagonal(D2, np.inf)
        # get kk smallest per row, then sort them by distance
        idx_part = np.argpartition(D2, kth=kk - 1, axis=1)[:, :kk]
        rows = np.arange(n)[:, None]
        dists = np.sqrt(D2[rows, idx_part])
        order = np.argsort(dists, axis=1)
        inds = idx_part[rows, order]
        dists = dists[rows, order]

    # local scales: sigma_i = distance to the k-th (farthest kept) neighbor.
    # Only the "knn" strategy exists; other local_scale values fall back to it.
    sigma = dists[:, -1].astype(np.float64) + 1e-12

    # O(1) membership tests for the mutual-neighbour filter.
    neighbor_sets = [set(row.tolist()) for row in inds] if mutual else None

    ei_list, ej_list, P_list = [], [], []
    neigh_idx = []
    for i in range(n):
        js = inds[i]
        neigh_idx.append(js)
        # weights with symmetric local scale
        wij = np.exp(-(dists[i] ** 2) / (sigma[i] * sigma[js] + 1e-12))
        if mutual:
            # keep j if i is among j's neighbors
            keep = [j for j in js if i in neighbor_sets[j]]
            if len(keep) == 0:
                keep = js  # fallback: keep all of i's neighbours
            js = np.asarray(keep, dtype=np.int64)
            # restrict weights to the kept neighbours (same order as inds[i])
            wij = wij[np.isin(inds[i], js)]
        ei_list.append(np.full(js.shape, i, dtype=np.int64))
        ej_list.append(js.astype(np.int64, copy=False))
        P_list.append(wij.astype(np.float32, copy=False))

    ei = np.concatenate(ei_list)
    ej = np.concatenate(ej_list)
    P_vals = np.concatenate(P_list)
    rho = np.zeros(n, dtype=np.float32)  # placeholder to mirror prior API
    return rho, sigma.astype(np.float32), ei, ej, P_vals, neigh_idx


# ---------------- demo ----------------

if __name__ == "__main__":
    if HAS_SKLEARN:
        # Two noisy half-moons, standardized. make_moons draws from the global
        # NumPy RNG here, so seeding it makes the demo reproducible.
        np.random.seed(0)
        X, y = make_moons(n_samples=5000, noise=0.08)
        X = StandardScaler().fit_transform(X)

        # Self-contained kNN graph over the standardized points.
        k = 10
        rho, sigmas, ei, ej, P_vals, neigh_idx = build_weighted_graph(X, k=k, mutual=False)

        modes = ["standard", "nystrom", "nystrompp", "diffmap", "pca"]
        comp_strategy = "umap"  # alternatives: "anchor", "pack"
        fig, axs = plt.subplots(1, len(modes), figsize=(4 * len(modes), 4), constrained_layout=True)

        for ax, mode in zip(axs, modes):
            emb = init_spectral(
                X, P_vals, ei, ej, n=X.shape[0],
                init_mode=mode,
                component_strategy=comp_strategy,
                diffusion_t=5.0,  # try 3-10 for diffmap
                random_state=0,
            )
            ax.scatter(emb[:, 0], emb[:, 1], c=y, cmap="Spectral", s=12, edgecolor="k", linewidth=0.25)
            ax.set_title(f"{mode} + {comp_strategy}")
            ax.set_xticks([])
            ax.set_yticks([])

        plt.show()
    else:
        print("Demo requires scikit-learn and matplotlib; please install them to run the demo.")